Add preliminary Muon+M-FSDP support#4486
Conversation
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
Route the emerging-optimizer factory through a Megatron-FSDP-specific
path when `ddp_config.use_megatron_fsdp` is set. Megatron-FSDP attaches
grads via `finish_grad_sync()` on DTensor params instead of via DDP's
main_grad buffers, so the standard `Float16OptimizerWithFloat16Params`
wrapper does not apply; we always wrap with `FP32Optimizer` instead and
drive the FSDP step contract from a thin `FSDPMuonChainedOptimizer`
adapter that calls `finish_grad_sync()` and
`install_optimized_model_weights()` around the inner step.
For now this supports ZeRO-0 ("no_shard") only; ZeRO-1/2/3 will work
without errors on the wiring but require a sharding-aware Muon variant
for numerical correctness, added in a follow-up.
Also patch `LayerWiseDistributedOptimizer._allgather_helper` to read
DTensor-backed params via `_local_tensor`, so the layer-wise + FSDP
combination can flatten the local shard rather than the global DTensor.
Add `FSDPZeROTensorParallelMuon`, a TensorParallelMuon subclass that: 1. Extracts the `Shard(0)` local tensor from each gradient DTensor: (`finish_grad_sync` produces a row-shard per DP rank for `optim`, `optim_grads` and `optim_grads_params`). 2. Allgathers the shards across the DP group to reconstruct the TP-local, DP-full gradient matrix. 3. Trims FSDP bucket-padding rows using the DTensor's declared global shape. 4. Delegates Newton-Schulz to the parent class (which handles the TP dimension via `newton_schulz_tp`). 5. Re-shards the orthogonalized result back to a `Shard(0)` DTensor with matching placements so the in-place update in `OrthogonalizedOptimizer.step` does not promote to `Replicate` and trip the global-shape check. The FSDP factory in `_build_megatron_fsdp_emerging_optimizer` now picks `FSDPZeROTensorParallelMuon` for any sharded inner-DP strategy and passes `pg_collection.dp_cp` for dense params and `pg_collection.expt_dp` for expert params (since expert grads reduce-scatter over a different group). "no_shard" continues to use plain `TensorParallelMuon`. DTensor is imported at module scope with a `_HAVE_DTENSOR` guard so the isinstance checks stay cheap and the module still imports on stacks without `torch.distributed.tensor`.
Three phases of tests for the Muon + Megatron-FSDP integration: - Phase 1: `FSDPMuonChainedOptimizer` adapter (single-rank, mock-based). Verifies the step contract – finish_grad_sync -> inner step -> install_optimized_model_weights – and attribute delegation. - Phase 2: `FSDPZeROTensorParallelMuon.orthogonalize` (multi-rank). Asserts the allgather -> Newton-Schulz -> reshard cycle is numerically equivalent to running NS on the full gradient and extracting the local row-shard, including FSDP padding edge cases. Includes a DTensor round-trip test that catches the `p.add_(orthogonalized_dtensor)` placement-promotion bug. - Phase 3: `_build_megatron_fsdp_emerging_optimizer` factory. Confirms the factory dispatches plain `TensorParallelMuon` for `no_shard` and `FSDPZeROTensorParallelMuon` for sharded strategies, and that expert vs. non-expert Muon instances receive `expt_dp` vs. `dp_cp` as their allgather group.
| linear_param_groups = _get_param_groups(model_chunks, config, config_overrides) | ||
|
|
||
| expert_param_groups: List[Dict[str, Any]] = [] | ||
| if not use_layer_wise: |
There was a problem hiding this comment.
Is this because layer-wise <> EP isn't compatible, so we just use Adam? (Gathering expert params is too heavyweight? Also, another complexity is EP-DP sharding, so we need to gather both.)
| if config is None or not config.use_precision_aware_optimizer: | ||
| opt.state[p]["exp_avg"] = torch.zeros_like(p.data) | ||
| opt.state[p]["exp_avg_sq"] = torch.zeros_like(p.data) |
There was a problem hiding this comment.
Also what does this case capture? With MFSDP, I think we need DTensors so we usually do a "dummy step" to get DTensor state based on the DTensor params.
Two examples:
- MCore: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/optimizer/distrib_optimizer.py#L1298 (
_init_optimizer_states_with_dummy_values) - Torch DCP: https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict.py#L609 (
_init_optim_state)
| # Float16OptimizerWithFloat16Params is incompatible with FSDP DTensor | ||
| # params (no .main_grad), so temporarily clear bf16 to prevent the | ||
| # LayerWiseDistributedOptimizer from re-wrapping each sub-optimizer. |
There was a problem hiding this comment.
Either we should extend the logic in LayerWiseDistributedOptimizer or the comment should be simpler, like this:
LayerWiseDistributedOptimizer(config.bf16)wraps the optimizer with a Float16 Megatron optimizer if BF16 is used. Megatron-FSDP only supports DistributedOptimizer and LayerWiseDistributedOptimizer.
Note: Layer-wise appears to be a permutation of optimizer params into different layer-specific groups distributed across DP-CP (or EDP). I am not sure if this is compatible with Megatron-FSDP, since we have un-even shards. Is it possible that, say, non-empty Params are all allocated to a different DP rank, but that rank's equivalent Params are empty, so no optimizer update?
| if hasattr(chunk, "finish_grad_sync") and hasattr(chunk, "module"): | ||
| mfsdp_models.append(chunk.module) |
There was a problem hiding this comment.
Can we just check directly for FullyShardedDataParallel? Many Megatron wrappers satisfy the current if condition!
| return None | ||
|
|
||
| def _is_nonempty_dtensor_param(self, param: torch.Tensor) -> bool: | ||
| dtensor = self._as_dtensor(param) |
There was a problem hiding this comment.
If I didn't mention this before, param should be from the optimizer.param_groups which should already be DTensor right? We really only need to go to Tensor land for any GEMMs or Muon ops.
|
|
||
| grad = p.grad | ||
| if grad is None: | ||
| local_grad = torch.zeros_like(mom_local) |
There was a problem hiding this comment.
Should be empty like the param to get a zero DTensor grad. Ideally, this should be guarded by requires grad and we error out for missing gradients right?
| torch.distributed.all_gather_object( | ||
| gathered_indices, local_boundary_indices, group=self.dp_group | ||
| ) |
There was a problem hiding this comment.
I think we don't need this AG. The current DP rank doesn't need to know what other DP ranks param indices need gathering. We know everything from the current rank's parameter's global shape and local shape.
| # Megatron-FSDP already shards optimizer state itself (ZeRO-1/2/3 | ||
| # via `--data-parallel-sharding-strategy`). Layering | ||
| # `LayerWiseDistributedOptimizer` on top would double-shard and, | ||
| # in practice, trips `TypeErrors` in its constructor call from | ||
| # `_build_megatron_fsdp_emerging_optimizer`. The M-FSDP factory | ||
| # already handles the "distributed" part via `FSDPMuonChainedOptimizer`. |
There was a problem hiding this comment.
The "double shard" here means MFSDP shards tensors unit-wise, while this layer-wise distopt distributes parameters as well. I think I mentioned this above, this needs to be brutally checked and confirmed, and this comment needs to be updated to not be completely vague / meaningless.
| assert args.outer_dp_sharding_strategy == "no_shard", ( | ||
| "Emerging optimizer with Megatron-FSDP does not support HSDP " | ||
| "(--outer-dp-sharding-strategy != no_shard) yet." | ||
| ) |
There was a problem hiding this comment.
This is a red flag. Because we are using DP-CP group, it should work OOTB, the optimizer doesn't care about Megatron-FSDP internal sharding strategies (like DP-inner or DP-outer). It should only see the fully-sharded optimizer state and main parameters on the cumulative DP-CP group. We def need support for HFSDP.
| # Megatron-FSDP itself requires `fsdp_dtensor` (asserted above), so | ||
| # the emerging-optimizer path must accept it here to avoid a | ||
| # contradictory assertion pair. | ||
| assert args.ckpt_format == "fsdp_dtensor", ( | ||
| "Emerging optimizer with Megatron-FSDP requires " | ||
| "--ckpt-format fsdp_dtensor." | ||
| ) |
There was a problem hiding this comment.
What does the optimizer state look like to DCP with layer-wise, if you could share, thanks!
If there is parameter distribution, we need to be careful about writing offsets.
Introduce Muon support to M-FSDP. Currently 1.5×–2.7× as slow compared to an Adam baseline with a 1B–8B DeepSeek-V3 proxy model. Peak memory slightly lower than with Adam (4–7 % less).